# ------------------------------------------------------------------------------
# Fits 5 LASSO models, one per training fold, each trained with that fold held out
# Updated by: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=10000]" bash 02_fit_lasso.sh {outcome} {split} {restriction}
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Fix the RNG state at the top of the script so runs are reproducible.
set.seed(seed = 1L)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(glue)
library(Matrix)
library(doParallel)
library(glmnet)
library(optparse)
library(here)
library(testit)
library(tidyverse)

u <- modules::use(here("lib", "util.R"))

# Command Line Args ------------------------------------------------------------
# Three character-valued options:
#   --outcome      name of the model config file (model_config/<outcome>.yml)
#   --split        data split name used in cohort/feature/model paths
#   --restriction  feature restriction (dropcc, justcc, dem, enc; any other
#                  value keeps all features)
arg_config <- lapply(
  c("--outcome", "--split", "--restriction"),
  function(flag) make_option(flag, type = "character")
)
arg_parser <- OptionParser(option_list = arg_config)
arg_list <- parse_args(arg_parser)
split <- arg_list$split

# Directories ------------------------------------------------------------------
message("Establishing directories...")
paths <- read_yaml(here("lib", "filepaths.yml"))

# Inputs for this split: cohort metadata and feature matrices
cohort_dir <- file.path(paths$modeling$oos, "cohorts", split)
features_dir <- file.path(paths$features$dir, split)

# Output location <oos>/models/<split>/<restriction>; it must already exist
save_dir <- file.path(
  paths$modeling$oos, "models", split, arg_list$restriction
)
assert("Models directory exists", dir.exists(save_dir))

# Directory holding the per-outcome model_config/*.yml files
build_models_dir <- here::here("code", "03_analysis", "01_build_models")

# Load Data --------------------------------------------------------------------
# Reads the outcome's model config, then the training cohort (row metadata) and
# training feature matrix for the configured downsample, and optionally keeps
# only a subset of feature columns according to --restriction.
message("Loading data...")
config <- read_yaml(
  file.path(build_models_dir, "model_config", glue("{arg_list$outcome}.yml"))
)
ds <- config$downsample
ids <- readRDS(file.path(cohort_dir, glue("train_cohort_{ds}.rds")))
x <- readRDS(file.path(features_dir, glue("train_features_{ds}.rds")))

# Every later step subsets `x` and `ids` with the same row indices, so the two
# objects must be row-aligned; fail loudly instead of silently mismatching.
assert(
  "Features and cohort have the same number of rows",
  nrow(x) == nrow(ids)
)

# Column filter for each restriction; any other value keeps every feature.
feature_names <- colnames(x)
keep_feats <- switch(arg_list$restriction,
  # drop chief complaint features
  dropcc = !grepl("ed_enc_t0d", feature_names),
  # keep only chief complaint features
  justcc = grepl("ed_enc_t0d", feature_names),
  dem = grepl("dem_", feature_names),
  enc = grepl("enc_", feature_names) & !grepl("_cc_", feature_names),
  NULL
)

if (!is.null(keep_feats)) {
  # drop = FALSE keeps a matrix even if only a single column survives
  x <- x[, which(keep_feats), drop = FALSE]
  assert("At least one feature survives the restriction", ncol(x) > 0)
}

# Subset Data ------------------------------------------------------------------
# Keep rows belonging to the configured population (all / tested / untested)
# that are not flagged for exclusion from modeling.
message("Subsetting data...")
keep_pop <- switch(config$population,
  all = rep(TRUE, nrow(ids)),
  tested = ids$test_010_day == TRUE,
  untested = ids$test_010_day == FALSE,
  # Without this default an unknown population yields NULL, which silently
  # empties the training data; fail fast instead.
  stop(
    "Unknown population in model config: ", config$population,
    call. = FALSE
  )
)

keep_pop <- keep_pop & !ids$exclude_modeling
# which() also drops rows where either flag is NA
keep <- which(keep_pop)

# drop = FALSE keeps `x` a matrix even if only one row survives
x <- x[keep, , drop = FALSE]
ids <- ids[keep, ]

# Tuning -----------------------------------------------------------------------
# For each fold i, fit a cross-validated LASSO (cv.glmnet, binomial, AUC) on the
# training rows outside fold i and outside the ensemble set, using the
# remaining folds as the CV folds, and save the cv.glmnet object.
message("Fitting lassos...")
n_folds <- 5L
for (i in seq_len(n_folds)) {
  message("Running model excluding fold ", i, "...")
  keep <- which(ids$train_fold != i & !(ids$in_ensemble))
  ids_sub <- ids[keep, ]
  # drop = FALSE keeps `x_sub` a matrix even if only one row survives
  x_sub <- x[keep, , drop = FALSE]
  print(nrow(ids_sub))
  print(nrow(x_sub))

  # one worker per remaining CV fold
  registerDoParallel(cores = uniqueN(ids_sub$train_fold))
  # renumber the remaining folds to be contiguous (1, 2, ..., n_folds - 1) by
  # shifting every fold above the held-out one down by one
  ids_sub <- ids_sub %>%
    mutate(train_fold = ifelse(train_fold > i, train_fold - 1, train_fold))

  tuning_result <- cv.glmnet(
    x = x_sub,
    y = ids_sub[[config$target]],
    foldid = ids_sub$train_fold,
    parallel = TRUE,
    family = "binomial",
    type.measure = "auc",
    alpha = 1
  )

  stopImplicitCluster()

  # glue's `.envir` must be an environment (not the `config` list), and `i`
  # lives in this loop's scope, so reference the config fields explicitly.
  result_path <- file.path(
    save_dir,
    glue("lasso__{config$target}__{config$population}__{i}.rds")
  )
  message("Saving fold ", i, " to ", result_path, "...")
  saveRDS(tuning_result, result_path, compress = FALSE)
}

# Done -------------------------------------------------------------------------
message("Done.")
